{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "es57M8iFDhHD"
      },
      "source": [
        "##### Copyright 2021 Google LLC.\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "oqudn4pXDmTO"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CVpl5q-NMBOp"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/graphics/blob/master/tensorflow_graphics/projects/radiance_fields/tiny_nerf.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/projects/radiance_fields/tiny_nerf.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vcpOzSMjDrm2"
      },
      "source": [
        "# Setup and imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_oTT3WZLA_SS"
      },
      "outputs": [],
      "source": [
        "%pip install tensorflow_graphics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DdWAz8P3BGJ1"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import tensorflow as tf\n",
        "import tensorflow.keras.layers as layers\n",
        "\n",
        "import tensorflow_graphics.projects.radiance_fields.data_loaders as data_loaders\n",
        "import tensorflow_graphics.projects.radiance_fields.utils as utils\n",
        "import tensorflow_graphics.rendering.camera.perspective as perspective\n",
        "import tensorflow_graphics.geometry.representation.ray as ray\n",
        "import tensorflow_graphics.math.feature_representation as feature_rep\n",
        "import tensorflow_graphics.rendering.volumetric.ray_radiance as ray_radiance"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bgAPuVhmDztJ"
      },
      "source": [
        "Please download the data from the original NeRF [repository](https://github.com/bmild/nerf). This tutorial uses the synthetic scenes (lego, ship, chair, etc.), which can be found [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). You can then either point `DATASET_DIR` to a local copy (if you run a custom kernel) or upload the data to Google Colab."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TyM9pwsCBWpO"
      },
      "outputs": [],
      "source": [
        "# Root folder of the downloaded synthetic scenes. It is passed as `dataset_dir`\n",
        "# to `data_loaders.load_synthetic_nerf_dataset` for both the train and val splits.\n",
        "DATASET_DIR = '/content/nerf_synthetic/'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HOHVUtbsBOPV"
      },
      "outputs": [],
      "source": [
        "#@title Parameters\n",
        "\n",
        "# Batch size handed to the training data loader.\n",
        "batch_size = 10 #@param {type:\"integer\"}\n",
        "# Number of frequencies given to the positional encoding of the ray points;\n",
        "# it also determines the input size of the network.\n",
        "n_posenc_freq = 6 #@param {type:\"integer\"}\n",
        "# Learning rate of the Adam optimizer.\n",
        "learning_rate = 0.0005 #@param {type:\"number\"}\n",
        "# Number of units of every hidden Dense layer of the network.\n",
        "n_filters = 256 #@param {type:\"integer\"}\n",
        "\n",
        "\n",
        "# Number of passes over the training dataset.\n",
        "num_epochs = 100 #@param {type:\"integer\"}\n",
        "# Number of rays requested from `perspective.random_rays` for each training batch.\n",
        "n_rays = 512 #@param {type:\"integer\"}\n",
        "# Near and far bounds given to `ray.sample_1d` when sampling points along a ray.\n",
        "near = 2.0 #@param {type:\"number\"}\n",
        "far = 6.0 #@param {type:\"number\"}\n",
        "# Number of points sampled along each ray.\n",
        "ray_steps = 64 #@param {type:\"integer\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D6WyAkrUKsWM"
      },
      "source": [
        "# Training a NeRF network"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XNrVXb01BeMb"
      },
      "outputs": [],
      "source": [
        "#@title Load the lego dataset { form-width: \"350px\" }\n",
        "\n",
        "# Load the training split of the 'lego' scene. Besides the dataset, the loader\n",
        "# returns the image height and width that the ray generation code needs.\n",
        "# NOTE(review): `scale=0.125` presumably downsamples the images - confirm in data_loaders.\n",
        "dataset, height, width = data_loaders.load_synthetic_nerf_dataset(\n",
        "    dataset_dir=DATASET_DIR,\n",
        "    dataset_name='lego',\n",
        "    split='train',\n",
        "    scale=0.125,\n",
        "    batch_size=batch_size)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jeLkVPqdBgZB"
      },
      "outputs": [],
      "source": [
        "#@title Prepare the NeRF model and optimizer { form-width: \"350px\" }\n",
        "\n",
        "# Size of one encoded 3D point; it has to equal the last dimension of the\n",
        "# features produced by `feature_rep.positional_encoding` for the ray points.\n",
        "input_dim = 3 + 2 * 3 * n_posenc_freq\n",
        "\n",
        "\n",
        "def get_model():\n",
        "  \"\"\"Builds the tiny NeRF multilayer perceptron.\n",
        "\n",
        "  The network stacks eight ReLU-activated Dense layers of `n_filters` units.\n",
        "  The input features are concatenated to the output of the fifth layer\n",
        "  (skip connection), and a final linear Dense layer emits 4 values per point.\n",
        "\n",
        "  Returns:\n",
        "    A `tf.keras.Model` taking `input_dim` features per point.\n",
        "  \"\"\"\n",
        "  n_hidden_layers = 8\n",
        "  skip_after = 4  # Index of the layer whose output gets the input appended.\n",
        "  with tf.name_scope(\"Network/\"):\n",
        "    encoded_points = layers.Input(shape=[input_dim])\n",
        "    hidden = encoded_points\n",
        "    for layer_index in range(n_hidden_layers):\n",
        "      hidden = layers.Dense(n_filters, activation=layers.ReLU())(hidden)\n",
        "      if layer_index == skip_after:\n",
        "        hidden = layers.concatenate([hidden, encoded_points], -1)\n",
        "    raw_rgba = layers.Dense(4)(hidden)\n",
        "    return tf.keras.Model(inputs=[encoded_points], outputs=[raw_rgba])\n",
        "\n",
        "\n",
        "model = get_model()\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4LqCqUNOK0z-"
      },
      "outputs": [],
      "source": [
        "# @title Set up the training procedure { form-width: \"350px\" }\n",
        "\n",
        "@tf.function\n",
        "def network_inference_and_rendering(ray_points, model):\n",
        "  \"\"\"Runs the network on the sampled ray points and renders one color per ray.\n",
        "\n",
        "  Args:\n",
        "    ray_points: A tensor of shape `[A, B, C, 3]` where A is the batch size,\n",
        "      B is the number of rays, C is the number of samples per ray.\n",
        "    model: The NeRF model to run. It is called on the flattened positional\n",
        "      encoding of the points and must return 4 values per point.\n",
        "\n",
        "  Returns:\n",
        "    A single tensor holding the rendered rgb value of every ray, i.e. the\n",
        "    first output of `ray_radiance.compute_radiance` (presumably of shape\n",
        "    `[A, B, 3]`).\n",
        "  \"\"\"\n",
        "  # Encode every sample point and flatten, so the network sees one point per row.\n",
        "  features_xyz = feature_rep.positional_encoding(ray_points, n_posenc_freq)\n",
        "  features_xyz = tf.reshape(features_xyz, [-1, tf.shape(features_xyz)[-1]])\n",
        "  rgba = model([features_xyz])\n",
        "  # Restore the leading dimensions of `ray_points`, with 4 channels per point.\n",
        "  target_shape = tf.concat([tf.shape(ray_points)[:-1], [4]], axis=-1)\n",
        "  rgba = tf.reshape(rgba, target_shape)\n",
        "  # Squash the color channels with a sigmoid and clamp the last channel at zero.\n",
        "  rgb, alpha = tf.split(rgba, [3, 1], axis=-1)\n",
        "  rgb = tf.sigmoid(rgb)\n",
        "  alpha = tf.nn.relu(alpha)\n",
        "  rgba = tf.concat([rgb, alpha], axis=-1)\n",
        "  dists = utils.get_distances_between_points(ray_points)\n",
        "  # Only the first of the three outputs (the rendered rgb) is used.\n",
        "  rgb_render, _, _ = ray_radiance.compute_radiance(rgba, dists)\n",
        "  return rgb_render\n",
        "\n",
        "\n",
        "@tf.function\n",
        "def train_step(ray_origin, ray_direction, gt_rgb):\n",
        "  \"\"\"Runs one optimization step of the NeRF network on a batch of rays.\n",
        "\n",
        "  Uses the notebook-level `model`, `optimizer`, `near`, `far` and `ray_steps`.\n",
        "\n",
        "  Args:\n",
        "    ray_origin: A tensor of shape `[A, B, 3]` where A is the batch size,\n",
        "      B is the number of rays.\n",
        "    ray_direction: A tensor of shape `[A, B, 3]` where A is the batch size,\n",
        "      B is the number of rays.\n",
        "    gt_rgb: A tensor of shape `[A, B, 3]` where A is the batch size,\n",
        "      B is the number of rays.\n",
        "\n",
        "  Returns:\n",
        "    A scalar: the `utils.l2_loss` between the rendered and ground-truth colors.\n",
        "  \"\"\"\n",
        "  with tf.GradientTape() as tape:\n",
        "    # Sample `ray_steps` points between `near` and `far` along every ray.\n",
        "    ray_points, _ = ray.sample_1d(\n",
        "        ray_origin,\n",
        "        ray_direction,\n",
        "        near=near,\n",
        "        far=far,\n",
        "        n_samples=ray_steps,\n",
        "        strategy='stratified')\n",
        "\n",
        "    rgb = network_inference_and_rendering(ray_points, model)\n",
        "    total_loss = utils.l2_loss(rgb, gt_rgb)\n",
        "  # Differentiate outside of the tape context and update the network weights.\n",
        "  gradients = tape.gradient(total_loss, model.trainable_variables)\n",
        "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
        "  return total_loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cMJA2RVxLC0-"
      },
      "outputs": [],
      "source": [
        "for epoch in range(num_epochs):\n",
        "  # Running sum of the per-batch losses, reported once the epoch is over.\n",
        "  epoch_loss = 0.0\n",
        "  for image, focal, principal_point, transform_matrix in dataset:\n",
        "    # Prepare the rays: draw `n_rays` random rays and their pixel locations\n",
        "    camera_rays, pixels_xy = perspective.random_rays(\n",
        "        focal, principal_point, height, width, n_rays)\n",
        "    # TF-Graphics camera rays to NeRF world rays\n",
        "    camera_rays = utils.change_coordinate_system(\n",
        "        camera_rays, (0., 0., 0.), (1., -1., -1.))\n",
        "    rays_org, rays_dir = utils.camera_rays_from_transformation_matrix(\n",
        "        camera_rays, transform_matrix)\n",
        "    # Flip the (x, y) pixel coordinates before gathering colors from the image\n",
        "    pixels_yx = tf.reverse(pixels_xy, axis=[-1])\n",
        "    pixels_rgba = tf.gather_nd(image, pixels_yx, batch_dims=1)\n",
        "    # Only the first three channels are used as the training target\n",
        "    pixels_rgb, _ = tf.split(pixels_rgba, [3, 1], axis=-1)\n",
        "    epoch_loss = epoch_loss + train_step(rays_org, rays_dir, pixels_rgb)\n",
        "  print('Epoch {0} loss: {1:.3f}'.format(epoch, epoch_loss))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CyulVmShLNbp"
      },
      "source": [
        "# Testing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DnjfqUG9LGmh"
      },
      "outputs": [],
      "source": [
        "# @title Load the test data\n",
        "\n",
        "# Load the validation split of the same scene, one unshuffled image per batch.\n",
        "# This rebinds the notebook-level `height` and `width` set when loading the\n",
        "# training split.\n",
        "test_dataset, height, width = data_loaders.load_synthetic_nerf_dataset(\n",
        "    dataset_dir=DATASET_DIR,\n",
        "    dataset_name='lego',\n",
        "    split='val',\n",
        "    scale=0.125,\n",
        "    batch_size=1,\n",
        "    shuffle=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ttXv6CZKMSVF"
      },
      "outputs": [],
      "source": [
        "# Render the first validation view and compare it with the reference image.\n",
        "for testimg, focal, principal_point, transform_matrix in test_dataset.take(1):\n",
        "  # Keep the only image of the batch and its first three channels.\n",
        "  testimg = testimg[0, :, :, :3]\n",
        "\n",
        "  # Request a single patch as large as the image to get the rays of the view.\n",
        "  img_rays, _ = perspective.random_patches(\n",
        "      focal,\n",
        "      principal_point,\n",
        "      height,\n",
        "      width,\n",
        "      patch_height=height,\n",
        "      patch_width=width,\n",
        "      scale=1.0)\n",
        "\n",
        "  # Break the test image into lines, so we don't run out of memory\n",
        "  rendered_rows = []\n",
        "  for row_rays in tf.split(img_rays, height, axis=1):\n",
        "    # TF-Graphics camera rays to NeRF world rays\n",
        "    row_rays = utils.change_coordinate_system(row_rays,\n",
        "                                              (0., 0., 0.),\n",
        "                                              (1., -1., -1.))\n",
        "    rays_org, rays_dir = utils.camera_rays_from_transformation_matrix(\n",
        "        row_rays,\n",
        "        transform_matrix)\n",
        "    ray_points, _ = ray.sample_1d(\n",
        "        rays_org,\n",
        "        rays_dir,\n",
        "        near=near,\n",
        "        far=far,\n",
        "        n_samples=ray_steps,\n",
        "        strategy='stratified')\n",
        "    rendered_rows.append(network_inference_and_rendering(ray_points, model))\n",
        "  # Stack the rendered lines back into one image.\n",
        "  final_image = tf.concat(rendered_rows, axis=0)\n",
        "\n",
        "  # Label both panels so the figure can be read on its own.\n",
        "  fig, ax = plt.subplots(1, 2)\n",
        "  ax[0].imshow(final_image)\n",
        "  ax[0].set_title('Rendered')\n",
        "  ax[1].imshow(testimg)\n",
        "  ax[1].set_title('Ground truth')\n",
        "  plt.show()\n",
        "  # PSNR computed as -10 * log10(MSE).\n",
        "  # NOTE(review): this assumes pixel values in [0, 1] (peak signal of 1) - confirm.\n",
        "  loss = tf.reduce_mean(tf.square(final_image - testimg))\n",
        "  psnr = -10. * tf.math.log(loss) / tf.math.log(10.)\n",
        "  print(psnr.numpy())"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "TFG-tiny_nerf.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
